{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import re\n",
    "import collections\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_textcleaning(string):\n",
    "    string = re.sub('[^A-Za-z ]+', ' ', string)\n",
    "    return re.sub(r'[ ]+', ' ', string.lower()).strip()\n",
    "\n",
    "def batch_sequence(sentences, dictionary, maxlen = 50):\n",
    "    np_array = np.zeros((len(sentences), maxlen), dtype = np.int32)\n",
    "    for no_sentence, sentence in enumerate(sentences):\n",
    "        current_no = 0\n",
    "        for no, word in enumerate(sentence.split()[: maxlen - 2]):\n",
    "            np_array[no_sentence, no] = dictionary.get(word, 1)\n",
    "            current_no = no\n",
    "        np_array[no_sentence, current_no + 1] = 3\n",
    "    return np_array\n",
    "\n",
    "def counter_words(sentences):\n",
    "    word_counter = collections.Counter()\n",
    "    word_list = []\n",
    "    num_lines, num_words = (0, 0)\n",
    "    for i in sentences:\n",
    "        words = re.findall('[\\\\w\\']+|[;:\\-\\(\\)&.,!?\"]', i)\n",
    "        word_counter.update(words)\n",
    "        word_list.extend(words)\n",
    "        num_lines += 1\n",
    "        num_words += len(words)\n",
    "    return word_counter, word_list, num_lines, num_words\n",
    "\n",
    "\n",
    "def build_dict(word_counter, vocab_size = 50000):\n",
    "    count = [['PAD', 0], ['UNK', 1], ['START', 2], ['END', 3]]\n",
    "    count.extend(word_counter.most_common(vocab_size))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    return dictionary, {word: idx for idx, word in dictionary.items()}\n",
    "\n",
    "def split_by_dot(string):\n",
    "    string = re.sub(\n",
    "        r'(?<!\\d)\\.(?!\\d)',\n",
    "        'SPLITTT',\n",
    "        string.replace('\\n', '').replace('/', ' '),\n",
    "    )\n",
    "    string = string.split('SPLITTT')\n",
    "    return [re.sub(r'[ ]+', ' ', sentence).strip() for sentence in string]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9923"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gather the sentences of both training books into one flat list.\n",
    "contents = []\n",
    "for book_path in ('books/Blood_Born', 'books/Dark_Thirst'):\n",
    "    with open(book_path) as fopen:\n",
    "        contents.extend(split_by_dot(fopen.read()))\n",
    "\n",
    "len(contents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8390"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalise every sentence, then keep only those longer than 20 characters.\n",
    "cleaned_sentences = (simple_textcleaning(sentence) for sentence in contents)\n",
    "contents = [sentence for sentence in cleaned_sentences if len(sentence) > 20]\n",
    "len(contents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9039"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyper-parameters used by the cells below.\n",
    "maxlen = 50\n",
    "embedding_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 16\n",
    "# Number of distinct whitespace-separated tokens across all cleaned sentences.\n",
    "vocabulary_size = len({word for sentence in contents for word in sentence.split()})\n",
    "vocabulary_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils import shuffle\n",
    "\n",
    "stride = 1\n",
    "# Every window of three consecutive sentences yields one\n",
    "# (previous, current, next) triple; windows start every `stride` sentences.\n",
    "window_starts = range(0, len(contents) - 2, stride)\n",
    "t_range = len(window_starts)\n",
    "left = [contents[start] for start in window_starts]\n",
    "middle = [contents[start + 1] for start in window_starts]\n",
    "right = [contents[start + 2] for start in window_starts]\n",
    "\n",
    "# The three lists are shuffled with one shared permutation so triples stay\n",
    "# aligned. The seed makes the order repeatable across kernel restarts, in\n",
    "# line with the seeded KMeans further down.\n",
    "left, middle, right = shuffle(left, middle, right, random_state = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count over every cleaned sentence, not just `middle`: the first and last\n",
    "# sentence only ever appear in `left` / `right`, so words unique to them\n",
    "# would otherwise be encoded as UNK. `vocabulary_size` is computed from\n",
    "# `contents` as well.\n",
    "word_counter, _, _, _ = counter_words(contents)\n",
    "dictionary, _ = build_dict(word_counter, vocab_size = vocabulary_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"Dilated-convolution sentence encoder with decoders for the BEFORE and AFTER sequences.\n",
    "\n",
    "    Constructing an instance adds placeholders, variables and ops to the\n",
    "    default TensorFlow graph and exposes them as attributes:\n",
    "\n",
    "    * BEFORE / INPUT / AFTER: int32 placeholders of shape [None, maxlen].\n",
    "    * get_thought: the encoder output summed over axis 1.\n",
    "    * attention: get_thought multiplied by the transposed embedding matrix.\n",
    "    * loss: sum of the sequence losses of the AFTER and BEFORE decoders.\n",
    "    * optimizer: the Adam op that minimises loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dict_size,\n",
    "        size_layers,\n",
    "        learning_rate,\n",
    "        maxlen,\n",
    "        num_blocks = 3,\n",
    "    ):\n",
    "        \"\"\"Build the whole graph.\n",
    "\n",
    "        dict_size: rows of the embedding matrix and units of the output layer.\n",
    "        size_layers: embedding width; also the channel count (block_size) of the conv stacks.\n",
    "        learning_rate: passed to tf.train.AdamOptimizer.\n",
    "        maxlen: second dimension of the three placeholders.\n",
    "        num_blocks: how many times the encoder repeats its dilation schedule.\n",
    "        \"\"\"\n",
    "        block_size = size_layers\n",
    "        self.BEFORE = tf.placeholder(tf.int32,[None,maxlen])\n",
    "        self.INPUT = tf.placeholder(tf.int32,[None,maxlen])\n",
    "        self.AFTER = tf.placeholder(tf.int32,[None,maxlen])\n",
    "        # Dynamic batch size, read from INPUT at run time.\n",
    "        self.batch_size = tf.shape(self.INPUT)[0]\n",
    "        # One projection to vocabulary logits, used by calculate_loss for both decoders.\n",
    "        self.output_layer = tf.layers.Dense(dict_size, name=\"output_layer\")\n",
    "        self.output_layer.build(size_layers)\n",
    "        # A single embedding matrix serves the encoder input, both decoder inputs\n",
    "        # and the `attention` product below.\n",
    "        self.embeddings = tf.Variable(tf.random_uniform([dict_size, size_layers], -1, 1))\n",
    "        embedded = tf.nn.embedding_lookup(self.embeddings, self.INPUT)\n",
    "\n",
    "        def residual_block(x, size, rate, block, reuse = False):\n",
    "            \"\"\"Gated dilated convolution block.\n",
    "\n",
    "            A tanh branch and a sigmoid branch (each with a quarter of x's\n",
    "            channels) are multiplied, then projected to block_size channels by\n",
    "            a kernel-size-1 tanh convolution. Returns (x + projection, projection).\n",
    "            The variable scope is named from `block` and `rate`; `reuse` is\n",
    "            forwarded to it.\n",
    "            \"\"\"\n",
    "            with tf.variable_scope(\n",
    "                'block_%d_%d' % (block, rate), reuse = reuse\n",
    "            ):\n",
    "                conv_filter = tf.layers.conv1d(\n",
    "                    x,\n",
    "                    x.shape[2] // 4,\n",
    "                    kernel_size = size,\n",
    "                    strides = 1,\n",
    "                    padding = 'same',\n",
    "                    dilation_rate = rate,\n",
    "                    activation = tf.nn.tanh,\n",
    "                )\n",
    "                conv_gate = tf.layers.conv1d(\n",
    "                    x,\n",
    "                    x.shape[2] // 4,\n",
    "                    kernel_size = size,\n",
    "                    strides = 1,\n",
    "                    padding = 'same',\n",
    "                    dilation_rate = rate,\n",
    "                    activation = tf.nn.sigmoid,\n",
    "                )\n",
    "                out = tf.multiply(conv_filter, conv_gate)\n",
    "                out = tf.layers.conv1d(\n",
    "                    out,\n",
    "                    block_size,\n",
    "                    kernel_size = 1,\n",
    "                    strides = 1,\n",
    "                    padding = 'same',\n",
    "                    activation = tf.nn.tanh,\n",
    "                )\n",
    "                return tf.add(x, out), out\n",
    "\n",
    "        # Encoder: a kernel-size-1 convolution, then num_blocks passes over the\n",
    "        # dilation rates 1, 2, 4, 8, 16 with kernel size 7.\n",
    "        forward = tf.layers.conv1d(\n",
    "            embedded, block_size, kernel_size = 1, strides = 1, padding = 'SAME'\n",
    "        )\n",
    "        # `zeros` accumulates the second output of every residual block.\n",
    "        zeros = tf.zeros_like(forward)\n",
    "        for i in range(num_blocks):\n",
    "            for r in [1, 2, 4, 8, 16]:\n",
    "                forward, s = residual_block(\n",
    "                    forward, size = 7, rate = r, block = i\n",
    "                )\n",
    "                zeros = tf.add(zeros, s)\n",
    "        # Only the accumulated sum feeds the final projection; the residual\n",
    "        # stream `forward` is overwritten here.\n",
    "        forward = tf.layers.conv1d(\n",
    "            zeros,\n",
    "            block_size,\n",
    "            kernel_size = 1,\n",
    "            strides = 1,\n",
    "            padding = 'SAME',\n",
    "            activation = tf.nn.tanh,\n",
    "        )\n",
    "        # Sum over axis 1 (the maxlen axis of the placeholder) gives one vector per row.\n",
    "        self.get_thought = tf.reduce_sum(forward,axis=1, name = 'logits')\n",
    "        \n",
    "        def decoder(labels, reuse):\n",
    "            \"\"\"Map a [batch, maxlen] id tensor to [batch, maxlen, block_size] decoder features.\n",
    "\n",
    "            The input is `labels` with its last column removed and a column of\n",
    "            2s (the START id in build_dict) prepended.\n",
    "\n",
    "            NOTE(review): this function never reads self.get_thought, so the loss\n",
    "            built from it does not pass through the encoder convolutions -- confirm\n",
    "            whether conditioning on the thought vector was intended.\n",
    "            NOTE(review): the two kernel-size-1 conv1d calls sit outside the\n",
    "            variable_scope in residual_block, so `reuse` presumably does not apply\n",
    "            to them and each decoder call gets its own copies -- confirm.\n",
    "            \"\"\"\n",
    "            main = tf.strided_slice(labels, [0, 0], [self.batch_size, -1], [1, 1])\n",
    "            shifted_labels = tf.concat([tf.fill([self.batch_size, 1], 2), main], 1)\n",
    "            decoder_in = tf.nn.embedding_lookup(self.embeddings, shifted_labels)\n",
    "            forward = tf.layers.conv1d(\n",
    "                decoder_in, block_size, kernel_size = 1, strides = 1, padding = 'SAME'\n",
    "            )\n",
    "            zeros = tf.zeros_like(forward)\n",
    "            # block = 10 keeps these scope names apart from the encoder's\n",
    "            # block_0.. scopes when num_blocks is at most 10.\n",
    "            for r in [8, 16, 24]:\n",
    "                forward, s = residual_block(forward, size = 7, rate = r, block = 10, reuse = reuse)\n",
    "                zeros = tf.add(zeros, s)\n",
    "            return tf.layers.conv1d(\n",
    "                zeros,\n",
    "                block_size,\n",
    "                kernel_size = 1,\n",
    "                strides = 1,\n",
    "                padding = 'SAME',\n",
    "                activation = tf.nn.tanh,\n",
    "            )\n",
    "        \n",
    "        # The first call creates the block_10_* variables, the second reuses them.\n",
    "        fw_logits = decoder(self.AFTER, False)\n",
    "        bw_logits = decoder(self.BEFORE, True)\n",
    "        # One score per embedding row for every thought vector.\n",
    "        self.attention = tf.matmul(\n",
    "            self.get_thought, tf.transpose(self.embeddings), name = 'attention'\n",
    "        )\n",
    "        self.loss = self.calculate_loss(fw_logits, self.AFTER) + self.calculate_loss(bw_logits, self.BEFORE)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)\n",
    "    \n",
    "    def calculate_loss(self, outputs, labels):\n",
    "        \"\"\"Sequence loss of `outputs` (projected by output_layer) against `labels`.\n",
    "\n",
    "        Positions whose label id is 0 get weight 0; positive ids get weight 1.\n",
    "        \"\"\"\n",
    "        mask = tf.cast(tf.sign(labels), tf.float32)\n",
    "        logits = self.output_layer(outputs)\n",
    "        return tf.contrib.seq2seq.sequence_loss(logits, labels, mask)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the default graph before (re)building the model in it.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "# The model's vocabulary size is len(dictionary), which includes the four\n",
    "# reserved tokens added by build_dict, rather than vocabulary_size.\n",
    "model = Model(len(dictionary), embedding_size, learning_rate, maxlen)\n",
    "# Initialise the variables created by the Model constructor.\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 525/525 [00:26<00:00, 19.71it/s, cost=11.9]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:24<00:00, 22.81it/s, cost=10.6]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:24<00:00, 22.86it/s, cost=9.86]\n",
      "train minibatch loop: 100%|██████████| 525/525 [00:24<00:00, 22.84it/s, cost=9.1] \n",
      "train minibatch loop: 100%|██████████| 525/525 [00:24<00:00, 22.81it/s, cost=8.42]\n"
     ]
    }
   ],
   "source": [
    "# Five passes over the shuffled triples, one optimizer step per mini-batch.\n",
    "n_samples = len(middle)\n",
    "for i in range(5):\n",
    "    pbar = tqdm(range(0, n_samples, batch_size), desc='train minibatch loop')\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, n_samples)\n",
    "        # Encode the same slice of each list so the three batches stay aligned.\n",
    "        feed = {\n",
    "            placeholder: batch_sequence(sentences[start : stop], dictionary, maxlen = maxlen)\n",
    "            for placeholder, sentences in (\n",
    "                (model.BEFORE, left),\n",
    "                (model.INPUT, middle),\n",
    "                (model.AFTER, right),\n",
    "            )\n",
    "        }\n",
    "        loss, _ = sess.run([model.loss, model.optimizer], feed_dict = feed)\n",
    "        pbar.set_postfix(cost=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode sentences 100-199 (after cleaning and filtering) of a third book.\n",
    "with open('books/Driftas_Quest') as f:\n",
    "    raw_text = f.read()\n",
    "\n",
    "book = [simple_textcleaning(sentence) for sentence in split_by_dot(raw_text)]\n",
    "book = [sentence for sentence in book if len(sentence) > 20]\n",
    "book = book[100:200]\n",
    "book_sequences = batch_sequence(book, dictionary, maxlen = maxlen)\n",
    "encoded, attention = sess.run(\n",
    "    [model.get_thought, model.attention],\n",
    "    feed_dict = {model.INPUT: book_sequences},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "drifta activated the alarm and the others were appearing for the ride home and at top speed they made it to the ship where grofinglaz was rushed to the medical bay. completely oblivious to the fact he was being observed from millions of miles away but there was something unnerving in the way he suddenly turned his head to look directly back at them. they set about their work slicing up the ore and making neat piles for the extraction team. so magnified was the viewer that they could see individual humans going about their business in a city street. that was spooky said lorsilkor. drifta had been there twice to his knowledge. chapter grofinglaz was already sitting upright in the sickbay when drifta visited. nozrendo had recorded it all then the ship was flashing through space close to the speed of the sunlight from that solar systems star. drifta took the hint and left to rest on his bunk and reflect on recent events. since a little kid i ve travelled from one planet to the next living mostly on ships and scratching a living not knowing where i came from or who my family are\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import pairwise_distances_argmin_min\n",
    "\n",
    "n_clusters = 10\n",
    "kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(encoded)\n",
    "# Mean position in `book` of the sentences assigned to each cluster.\n",
    "avg = [np.mean(np.where(kmeans.labels_ == j)[0]) for j in range(n_clusters)]\n",
    "# Index of the encoded sentence nearest to each cluster centre.\n",
    "closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_,encoded)\n",
    "# One representative sentence per cluster, ordered by the clusters' mean position.\n",
    "ordering = sorted(range(n_clusters), key=avg.__getitem__)\n",
    "print('. '.join(book[closest[cluster]] for cluster in ordering))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Important words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['butting',\n",
       " 'flames',\n",
       " 'toffee',\n",
       " 'sigma',\n",
       " 'twice',\n",
       " 'shouldered',\n",
       " 'inspire',\n",
       " 'buttoning',\n",
       " 'reassure',\n",
       " 'perhaps']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary ids ranked by their mean `attention` score over the encoded\n",
    "# sentences, highest first; show the words behind the top ten.\n",
    "mean_scores = attention.mean(axis=0)\n",
    "indices = np.argsort(mean_scores)[::-1]\n",
    "rev_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "[rev_dictionary[word_id] for word_id in indices[:10]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
